from typing import Optional

from starkware.cairo.lang.compiler.ast.ast_objects_test_utils import remove_parentheses
from starkware.cairo.lang.compiler.ast.parentheses_expr_wrapper import ParenthesesExpressionWrapper
from starkware.cairo.lang.compiler.parser import parse_expr


def test_add_and_format_parentheses():
    """
    Checks that both format() and the ParenthesesExpressionWrapper insert parentheses exactly
    where they are required (and, in a few documented cases, where they aid readability).
    """
    wrapper = ParenthesesExpressionWrapper()

    def check(
        expr_str: str,
        expected_formatted_str: Optional[str] = None,
        keep_original_parentheses: bool = False,
    ):
        """
        Parses expr_str and, unless keep_original_parentheses is set, strips every
        ExprParentheses node from the resulting tree. The tree is then rendered twice: once via
        its format() method and once by joining the particles produced after running it through
        the parentheses wrapper. Both renderings must equal the expected string (expr_str when
        expected_formatted_str is omitted). In addition, the wrapper's output AST must equal the
        AST obtained by parsing the expected string.
        """
        target = expr_str if expected_formatted_str is None else expected_formatted_str

        tree = parse_expr(expr_str)
        if not keep_original_parentheses:
            tree = remove_parentheses(tree)

        assert tree.format() == target

        rewrapped = wrapper.visit(tree)
        assert rewrapped == parse_expr(target)
        assert "".join(map(str, rewrapped.get_particles())) == target

    # Expressions whose required parentheses must be restored verbatim.
    round_trip_cases = [
        "(a + b) * (c - d) * e * f",
        "x - (a + b) - (c - d) - e * f",
        "a * b / (c * d)",
        "a + b + c - d + e * f",
        "-(a + b + c)",
        "&(a + b)",
        "(a * b) ** (c * d)",
        "((a ** b) ** c) ** d",
        "a ** (&b)",
        "new (a + b)",
        "(new a) * (new b)",
        "-&(-(new (-a)))",
    ]
    for source in round_trip_cases:
        check(source)

    # Parentheses are added even where the parser would not strictly need them.
    readability_cases = [
        ("a + -b + c", "a + (-b) + c"),
        ("a ** b ** c ** d", "a ** (b ** (c ** d))"),
    ]
    for source, expected in readability_cases:
        check(source, expected)

    # Non-atomized operands of Dot, Subscript, and NewOperator expressions get parenthesized.
    non_atom_operand_cases = [
        "(x * y).z",
        "(-x).y",
        "(&x).y",
        "(new x).y",
        "(x * y)[z]",
        "(-x)[y]",
        "(&x)[y]",
        "(new x)[y]",
    ]
    for source in non_atom_operand_cases:
        check(source)

    # Stripped parentheses stay out when unnecessary, and spacing is normalized.
    unnecessary_parentheses_cases = [
        ("&(x.y)", "&x.y"),
        ("-(x.y)", "-x.y"),
        ("new (x.y)", "new x.y"),
        ("(x.y)*z", "x.y * z"),
        ("x-(y.z)", "x - y.z"),
        ("([x].y).z", "[x].y.z"),
        ("&(x[y])", "&x[y]"),
        ("-(x[y])", "-x[y]"),
        ("new (x[y])", "new x[y]"),
        ("(x[y])*z", "x[y] * z"),
        ("x-(y[z])", "x - y[z]"),
        ("(([x][y])[z])", "[x][y][z]"),
        ("x[(y+z)]", "x[y + z]"),
        ("[((x+y) + z)]", "[x + y + z]"),
        ("x + (&y)", "x + &y"),
        ("x * y / (z ** w)", "x * y / z ** w"),
    ]
    for source, expected in unnecessary_parentheses_cases:
        check(source, expected)

    # Redundant parentheses that were present in the source are left untouched.
    preserved_parentheses_cases = [
        "(a * (b + c))",
        "((a * ((b + c))))",
        "(x + y)[z]",
    ]
    for source in preserved_parentheses_cases:
        check(source, keep_original_parentheses=True)
